/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr.search;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.ResourceLoader;
import org.apache.lucene.util.ResourceLoaderAware;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.FeatureLogger;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.ltr.interleaving.Interleaving;
import org.apache.solr.ltr.interleaving.LTRInterleavingQuery;
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.ltr.store.rest.ManagedModelStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.search.QParser;
import org.apache.solr.search.QParserPlugin;
import org.apache.solr.search.SyntaxError;
import org.apache.solr.util.SolrPluginUtils;

/**
 * Plug into solr a rerank model.
 *
 * <p>Learning to Rank Query Parser Syntax: rq={!ltr model=6029760550880411648 reRankDocs=300
 * efi.myCompanyQueryIntent=0.98}
 *
 * <p>Exactly one or two {@code model} local params must be supplied. With one model a plain
 * {@link LTRQuery} is produced; with two models the results of both are interleaved through an
 * {@link LTRInterleavingQuery}. The reserved model name {@code _OriginalRanking_} stands for the
 * original (non re-ranked) ordering.
 */
public class LTRQParserPlugin extends QParserPlugin
    implements ResourceLoaderAware, ManagedResourceObserver {
  public static final String NAME = "ltr";

  /** Reserved model name that selects the original ranking instead of a stored model. */
  private static final String ORIGINAL_RANKING = "_OriginalRanking_";

  // params for setting custom external info that features can use, like query
  // intent
  static final String EXTERNAL_FEATURE_INFO = "efi.";

  // Both stores are assigned in onManagedResourceInitialized(); they stay null until the
  // corresponding managed resource has been reported to this observer.
  private ManagedFeatureStore fr = null;
  private ManagedModelStore mr = null;

  // Obtained from LTRThreadModule.getInstance(args) in init(); parse() tolerates null here.
  private LTRThreadModule threadManager = null;

  /** query parser plugin: the name of the attribute for setting the model */
  public static final String MODEL = "model";

  /** query parser plugin: default number of documents to rerank */
  public static final int DEFAULT_RERANK_DOCS = 200;

  /** query parser plugin: the param that selects the number of documents to rerank */
  public static final String RERANK_DOCS = "reRankDocs";

  /** query parser plugin: default interleaving algorithm */
  public static final String DEFAULT_INTERLEAVING_ALGORITHM = Interleaving.TEAM_DRAFT;

  /** query parser plugin:the param that selects the interleaving algorithm to use */
  public static final String INTERLEAVING_ALGORITHM = "interleavingAlgorithm";

  /**
   * Initializes the plugin from its configuration arguments.
   *
   * @param args plugin configuration; handed to {@link LTRThreadModule#getInstance} and to {@link
   *     SolrPluginUtils#invokeSetters}
   */
  @Override
  public void init(NamedList<?> args) {
    super.init(args);
    threadManager = LTRThreadModule.getInstance(args);
    SolrPluginUtils.invokeSetters(this, args);
  }

  /**
   * Creates a new {@link LTRQParser} bound to this plugin instance.
   *
   * @param qstr the query string
   * @param localParams local params of the query, may be {@code null}
   * @param params request params
   * @param req the current request
   * @return a parser for the given request
   */
  @Override
  public QParser createParser(
      String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
    return new LTRQParser(qstr, localParams, params, req);
  }

  /**
   * Given a set of local SolrParams, extract all of the efi.key=value params into a map
   *
   * <p>Each value is wrapped in a single-element array holding the result of {@code
   * localParams.get(name)}.
   *
   * @param localParams Local request parameters that might contain efi params; may be {@code
   *     null}, in which case an empty map is returned
   * @return Map of efi params, where the key is the name of the efi param (without the {@code
   *     efi.} prefix), and the value is the value of the efi param
   */
  public static Map<String, String[]> extractEFIParams(SolrParams localParams) {
    final Map<String, String[]> externalFeatureInfo = new HashMap<>();
    if (localParams == null) {
      // no local params at all means no external feature info; avoid a NullPointerException
      return externalFeatureInfo;
    }
    for (final Iterator<String> it = localParams.getParameterNamesIterator(); it.hasNext(); ) {
      final String name = it.next();
      if (name.startsWith(EXTERNAL_FEATURE_INFO)) {
        externalFeatureInfo.put(
            name.substring(EXTERNAL_FEATURE_INFO.length()), new String[] {localParams.get(name)});
      }
    }
    return externalFeatureInfo;
  }

  /**
   * Registers this plugin as observer of the managed feature store and the managed model store.
   *
   * @param loader the resource loader; cast to {@link SolrResourceLoader}
   * @throws IOException declared by the {@link ResourceLoaderAware} contract
   */
  @Override
  public void inform(ResourceLoader loader) throws IOException {
    final SolrResourceLoader solrResourceLoader = (SolrResourceLoader) loader;
    ManagedFeatureStore.registerManagedFeatureStore(solrResourceLoader, this);
    ManagedModelStore.registerManagedModelStore(solrResourceLoader, this);
  }

  /**
   * Records the managed store that was just initialized. Once both the feature store and the
   * model store are known, the feature store is attached to the model store and the stored models
   * are loaded.
   *
   * @param args initialization arguments of the managed resource (unused here)
   * @param res the managed resource that has been initialized
   * @throws SolrException declared by the {@link ManagedResourceObserver} contract
   */
  @Override
  public void onManagedResourceInitialized(NamedList<?> args, ManagedResource res)
      throws SolrException {
    if (res instanceof ManagedFeatureStore) {
      fr = (ManagedFeatureStore) res;
    }
    if (res instanceof ManagedModelStore) {
      mr = (ManagedModelStore) res;
    }
    if (mr != null && fr != null) {
      mr.setManagedFeatureStore(fr);
      // now we can safely load the models
      mr.loadStoredModels();
    }
  }

  /**
   * Parser that turns the {@code ltr} local params into a re-ranking query. It is an inner (non
   * static) class because it reads the enclosing plugin's model store and thread manager.
   */
  public class LTRQParser extends QParser {

    public LTRQParser(
        String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
      super(qstr, localParams, params, req);
    }

    /**
     * Builds the re-ranking query for the requested model(s).
     *
     * @return an {@link LTRQuery} when one model is given, an {@link LTRInterleavingQuery} when
     *     two models are given
     * @throws SolrException with {@code BAD_REQUEST} when no local params or not exactly one or
     *     two models are given, a model name is empty or unknown, the feature logger's store
     *     differs from the model's store while not logging all features, or {@code reRankDocs}
     *     is not positive
     * @throws SyntaxError declared by the {@link QParser} contract
     */
    @Override
    public Query parse() throws SyntaxError {
      if (threadManager != null) {
        threadManager.setExecutor(
            req.getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
      }
      // ReRanking Model
      // localParams may be null (e.g. when this parser is selected without the {!ltr ...}
      // local-params syntax); treat that as "no model given" so the caller gets a BAD_REQUEST
      // instead of a NullPointerException.
      final String[] modelNames =
          (localParams == null) ? null : localParams.getParams(LTRQParserPlugin.MODEL);
      if ((modelNames == null) || (modelNames.length != 1 && modelNames.length != 2)) {
        throw new SolrException(
            SolrException.ErrorCode.BAD_REQUEST, "Must provide one or two models in the request");
      }
      final boolean isInterleaving = (modelNames.length > 1);
      final boolean isLoggingFeatures = SolrQueryRequestContextUtils.isLoggingFeatures(req);

      final Map<String, String[]> externalFeatureInfo = extractEFIParams(localParams);

      // rerankingQuery always refers to the query built in the current/last iteration;
      // rerankingQueries is only populated (and later used) in the interleaving case.
      LTRScoringQuery rerankingQuery = null;
      LTRInterleavingScoringQuery[] rerankingQueries =
          new LTRInterleavingScoringQuery[modelNames.length];
      for (int i = 0; i < modelNames.length; i++) {
        if (modelNames[i].isEmpty()) {
          throw new SolrException(
              SolrException.ErrorCode.BAD_REQUEST,
              "the " + LTRQParserPlugin.MODEL + " " + i + " is empty");
        }
        if (!ORIGINAL_RANKING.equals(modelNames[i])) {
          final LTRScoringModel ltrScoringModel = mr.getModel(modelNames[i]);
          if (ltrScoringModel == null) {
            throw new SolrException(
                SolrException.ErrorCode.BAD_REQUEST,
                "cannot find " + LTRQParserPlugin.MODEL + " " + modelNames[i]);
          }

          if (isInterleaving) {
            rerankingQuery =
                rerankingQueries[i] =
                    new LTRInterleavingScoringQuery(
                        ltrScoringModel, externalFeatureInfo, threadManager);
          } else {
            rerankingQuery =
                new LTRScoringQuery(ltrScoringModel, externalFeatureInfo, threadManager);
            rerankingQueries[i] = null;
          }

          if (isLoggingFeatures) {
            FeatureLogger featureLogger = SolrQueryRequestContextUtils.getFeatureLogger(req);
            final String modelFeatureStore = ltrScoringModel.getFeatureStoreName();
            final String loggerFeatureStore = SolrQueryRequestContextUtils.getFvStoreName(req);
            // a logger without an explicit store is considered to use the model's store
            final boolean isSameFeatureStore =
                (modelFeatureStore.equals(loggerFeatureStore) || loggerFeatureStore == null);

            if (isSameFeatureStore) {
              if (featureLogger.isLoggingAll() == null) {
                featureLogger.setLogAll(false); // default to log only model features
              }
              // only in this case is the logger attached to the scoring query itself
              rerankingQuery.setFeatureLogger(featureLogger);
            } else {
              if (featureLogger.isLoggingAll() == null) {
                featureLogger.setLogAll(true); // default to log all features from the store
              }
              if (!featureLogger.isLoggingAll()) {
                throw new SolrException(
                    SolrException.ErrorCode.BAD_REQUEST,
                    "the feature store '"
                        + loggerFeatureStore
                        + "' in the logger is different from the model feature store '"
                        + modelFeatureStore
                        + "', you can only log all the features from the store");
              }
            }
          }
        } else {
          rerankingQuery =
              rerankingQueries[i] = new OriginalRankingLTRScoringQuery(ORIGINAL_RANKING);
        }

        // External features
        rerankingQuery.setRequest(req);
      }

      int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS);
      if (reRankDocs <= 0) {
        throw new SolrException(
            SolrException.ErrorCode.BAD_REQUEST, "Must rerank at least 1 document");
      }
      if (!isInterleaving) {
        SolrQueryRequestContextUtils.setScoringQueries(req, new LTRScoringQuery[] {rerankingQuery});
        return new LTRQuery(rerankingQuery, reRankDocs);
      } else {
        String interleavingAlgorithm =
            localParams.get(INTERLEAVING_ALGORITHM, DEFAULT_INTERLEAVING_ALGORITHM);
        SolrQueryRequestContextUtils.setScoringQueries(req, rerankingQueries);
        return new LTRInterleavingQuery(
            Interleaving.getImplementation(interleavingAlgorithm), rerankingQueries, reRankDocs);
      }
    }
  }
}
